Ungraded Lab: Fully Convolutional Neural Networks for Image Segmentation

This notebook illustrates how to build a Fully Convolutional Neural Network for semantic image segmentation.

You will train the model on a custom dataset prepared by divamgupta. This contains video frames from a moving vehicle and is a subsample of the CamVid dataset.

You will be using a pretrained VGG-16 network for the feature extraction path, then followed by an FCN-8 network for upsampling and generating the predictions. The output will be a label map (i.e. segmentation mask) with predictions for 12 classes. Let's begin!

Imports

In [1]:
import os
import zipfile
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
import seaborn as sns

print("Tensorflow version " + tf.__version__)
Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.
Tensorflow version 2.12.0

Download the Dataset

We hosted the dataset in a Google bucket so you will need to download it first and unzip to a local directory.

In [2]:
# download the dataset (zipped file)
!gdown --id 0B0d9ZiqAgFkiOHR1NTJhWVJMNEU -O /tmp/fcnn-dataset.zip 
/usr/local/lib/python3.10/dist-packages/gdown/cli.py:121: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
  warnings.warn(
Downloading...
From: https://drive.google.com/uc?id=0B0d9ZiqAgFkiOHR1NTJhWVJMNEU
To: /tmp/fcnn-dataset.zip
100% 126M/126M [00:05<00:00, 23.2MB/s]

Troubleshooting: If you get a download error saying "Cannot retrieve the public link of the file.", please run the next two cells below to download the dataset. Otherwise, please skip them.

In [3]:
%%writefile download.sh

#!/bin/bash
fileid="0B0d9ZiqAgFkiOHR1NTJhWVJMNEU"
filename="/tmp/fcnn-dataset.zip"
html=`curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}"`
curl -Lb ./cookie "https://drive.google.com/uc?export=download&`echo ${html}|grep -Po '(confirm=[a-zA-Z0-9\-_]+)'`&id=${fileid}" -o ${filename}
Writing download.sh
In [4]:
# NOTE: Please only run this if downloading with gdown did not work.
# This will run the script created above.
!bash download.sh
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  119M  100  119M    0     0   106M      0  0:00:01  0:00:01 --:--:--  191M

You can extract the downloaded zip files with this code:

In [5]:
# extract the downloaded dataset to a local directory: /tmp/fcnn
local_zip = '/tmp/fcnn-dataset.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp/fcnn')
zip_ref.close()

The dataset you just downloaded contains folders for images and annotations. The images contain the video frames while the annotations contain the pixel-wise label maps. Each label map has the shape (height, width , 1) with each point in this space denoting the corresponding pixel's class. Classes are in the range [0, 11] (i.e. 12 classes) and the pixel labels correspond to these classes:

Value Class Name
0 sky
1 building
2 column/pole
3 road
4 side walk
5 vegetation
6 traffic light
7 fence
8 vehicle
9 pedestrian
10 byciclist
11 void

For example, if a pixel is part of a road, then that point will be labeled 3 in the label map. Run the cell below to create a list containing the class names:

  • Note: bicyclist is mispelled as 'byciclist' in the dataset. We won't handle data cleaning in this example, but you can inspect and clean the data if you want to use this as a starting point for a personal project.
In [6]:
# pixel labels in the video frames
class_names = ['sky', 'building','column/pole', 'road', 'side walk', 'vegetation', 'traffic light', 'fence', 'vehicle', 'pedestrian', 'byciclist', 'void']

Load and Prepare the Dataset

Next, you will load and prepare the train and validation sets for training. There are some preprocessing steps needed before the data is fed to the model. These include:

  • resizing the height and width of the input images and label maps (224 x 224px by default)
  • normalizing the input images' pixel values to fall in the range [-1, 1]
  • reshaping the label maps from (height, width, 1) to (height, width, 12) with each slice along the third axis having 1 if it belongs to the class corresponding to that slice's index else 0. For example, if a pixel is part of a road, then using the table above, that point at slice #3 will be labeled 1 and it will be 0 in all other slices. To illustrate using simple arrays:
# if we have a label map with 3 classes...
n_classes = 3
# and this is the original annotation...
orig_anno = [0 1 2]
# then the reshaped annotation will have 3 slices and its contents will look like this:
reshaped_anno = [1 0 0][0 1 0][0 0 1]

The following function will do the preprocessing steps mentioned above.

In [7]:
def map_filename_to_image_and_mask(t_filename, a_filename, height=224, width=224):
  '''
  Preprocesses the dataset by:
    * resizing the input image and label maps
    * normalizing the input image pixels
    * reshaping the label maps from (height, width, 1) to (height, width, 12)

  Args:
    t_filename (string) -- path to the raw input image
    a_filename (string) -- path to the raw annotation (label map) file
    height (int) -- height in pixels to resize to
    width (int) -- width in pixels to resize to

  Returns:
    image (tensor) -- preprocessed image
    annotation (tensor) -- preprocessed annotation
  '''

  # Convert image and mask files to tensors 
  img_raw = tf.io.read_file(t_filename)
  anno_raw = tf.io.read_file(a_filename)
  image = tf.image.decode_jpeg(img_raw)
  annotation = tf.image.decode_jpeg(anno_raw)
 
  # Resize image and segmentation mask
  image = tf.image.resize(image, (height, width,))
  annotation = tf.image.resize(annotation, (height, width,))
  image = tf.reshape(image, (height, width, 3,))
  annotation = tf.cast(annotation, dtype=tf.int32)
  annotation = tf.reshape(annotation, (height, width, 1,))
  stack_list = []

  # Reshape segmentation masks
  for c in range(len(class_names)):
    mask = tf.equal(annotation[:,:,0], tf.constant(c))
    stack_list.append(tf.cast(mask, dtype=tf.int32))
  
  annotation = tf.stack(stack_list, axis=2)

  # Normalize pixels in the input image
  image = image/127.5
  image -= 1

  return image, annotation

The dataset also already has separate folders for train and test sets. As described earlier, these sets will have two folders: one corresponding to the images, and the other containing the annotations.

In [8]:
# show folders inside the dataset you downloaded
!ls /tmp/fcnn/dataset1
annotations_prepped_test   images_prepped_test
annotations_prepped_train  images_prepped_train

You will use the following functions to create the tensorflow datasets from the images in these folders. Notice that before creating the batches in the get_training_dataset() and get_validation_set(), the images are first preprocessed using the map_filename_to_image_and_mask() function you defined earlier.

In [9]:
# Utilities for preparing the datasets

BATCH_SIZE = 64

def get_dataset_slice_paths(image_dir, label_map_dir):
  '''
  generates the lists of image and label map paths
  
  Args:
    image_dir (string) -- path to the input images directory
    label_map_dir (string) -- path to the label map directory

  Returns:
    image_paths (list of strings) -- paths to each image file
    label_map_paths (list of strings) -- paths to each label map
  '''
  image_file_list = os.listdir(image_dir)
  label_map_file_list = os.listdir(label_map_dir)
  image_paths = [os.path.join(image_dir, fname) for fname in image_file_list]
  label_map_paths = [os.path.join(label_map_dir, fname) for fname in label_map_file_list]

  return image_paths, label_map_paths


def get_training_dataset(image_paths, label_map_paths):
  '''
  Prepares shuffled batches of the training set.
  
  Args:
    image_paths (list of strings) -- paths to each image file in the train set
    label_map_paths (list of strings) -- paths to each label map in the train set

  Returns:
    tf Dataset containing the preprocessed train set
  '''
  training_dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_map_paths))
  training_dataset = training_dataset.map(map_filename_to_image_and_mask)
  training_dataset = training_dataset.shuffle(100, reshuffle_each_iteration=True)
  training_dataset = training_dataset.batch(BATCH_SIZE)
  training_dataset = training_dataset.repeat()
  training_dataset = training_dataset.prefetch(-1)

  return training_dataset


def get_validation_dataset(image_paths, label_map_paths):
  '''
  Prepares batches of the validation set.
  
  Args:
    image_paths (list of strings) -- paths to each image file in the val set
    label_map_paths (list of strings) -- paths to each label map in the val set

  Returns:
    tf Dataset containing the preprocessed validation set
  '''
  validation_dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_map_paths))
  validation_dataset = validation_dataset.map(map_filename_to_image_and_mask)
  validation_dataset = validation_dataset.batch(BATCH_SIZE)
  validation_dataset = validation_dataset.repeat()  

  return validation_dataset

You can now generate the training and validation sets by running the cell below.

In [10]:
# get the paths to the images
training_image_paths, training_label_map_paths = get_dataset_slice_paths('/tmp/fcnn/dataset1/images_prepped_train/','/tmp/fcnn/dataset1/annotations_prepped_train/')
validation_image_paths, validation_label_map_paths = get_dataset_slice_paths('/tmp/fcnn/dataset1/images_prepped_test/','/tmp/fcnn/dataset1/annotations_prepped_test/')

# generate the train and val sets
training_dataset = get_training_dataset(training_image_paths, training_label_map_paths)
validation_dataset = get_validation_dataset(validation_image_paths, validation_label_map_paths)

Let's Take a Look at the Dataset

You will also need utilities to help visualize the dataset and the model predictions later. First, you need to assign a color mapping to the classes in the label maps. Since our dataset has 12 classes, you need to have a list of 12 colors. We can use the color_palette() from Seaborn to generate this.

In [11]:
# generate a list that contains one color for each class
colors = sns.color_palette(None, len(class_names))

# print class name - normalized RGB tuple pairs
# the tuple values will be multiplied by 255 in the helper functions later
# to convert to the (0,0,0) to (255,255,255) RGB values you might be familiar with
for class_name, color in zip(class_names, colors):
  print(f'{class_name} -- {color}')
sky -- (0.12156862745098039, 0.4666666666666667, 0.7058823529411765)
building -- (1.0, 0.4980392156862745, 0.054901960784313725)
column/pole -- (0.17254901960784313, 0.6274509803921569, 0.17254901960784313)
road -- (0.8392156862745098, 0.15294117647058825, 0.1568627450980392)
side walk -- (0.5803921568627451, 0.403921568627451, 0.7411764705882353)
vegetation -- (0.5490196078431373, 0.33725490196078434, 0.29411764705882354)
traffic light -- (0.8901960784313725, 0.4666666666666667, 0.7607843137254902)
fence -- (0.4980392156862745, 0.4980392156862745, 0.4980392156862745)
vehicle -- (0.7372549019607844, 0.7411764705882353, 0.13333333333333333)
pedestrian -- (0.09019607843137255, 0.7450980392156863, 0.8117647058823529)
byciclist -- (0.12156862745098039, 0.4666666666666667, 0.7058823529411765)
void -- (1.0, 0.4980392156862745, 0.054901960784313725)
In [12]:
# Visualization Utilities

def fuse_with_pil(images):
  '''
  Creates a blank image and pastes input images

  Args:
    images (list of numpy arrays) - numpy array representations of the images to paste
  
  Returns:
    PIL Image object containing the images
  '''

  widths = (image.shape[1] for image in images)
  heights = (image.shape[0] for image in images)
  total_width = sum(widths)
  max_height = max(heights)

  new_im = PIL.Image.new('RGB', (total_width, max_height))

  x_offset = 0
  for im in images:
    pil_image = PIL.Image.fromarray(np.uint8(im))
    new_im.paste(pil_image, (x_offset,0))
    x_offset += im.shape[1]
  
  return new_im


def give_color_to_annotation(annotation):
  '''
  Converts a 2-D annotation to a numpy array with shape (height, width, 3) where
  the third axis represents the color channel. The label values are multiplied by
  255 and placed in this axis to give color to the annotation

  Args:
    annotation (numpy array) - label map array
  
  Returns:
    the annotation array with an additional color channel/axis
  '''
  seg_img = np.zeros( (annotation.shape[0],annotation.shape[1], 3) ).astype('float')
  
  for c in range(12):
    segc = (annotation == c)
    seg_img[:,:,0] += segc*( colors[c][0] * 255.0)
    seg_img[:,:,1] += segc*( colors[c][1] * 255.0)
    seg_img[:,:,2] += segc*( colors[c][2] * 255.0)
  
  return seg_img


def show_predictions(image, labelmaps, titles, iou_list, dice_score_list):
  '''
  Displays the images with the ground truth and predicted label maps

  Args:
    image (numpy array) -- the input image
    labelmaps (list of arrays) -- contains the predicted and ground truth label maps
    titles (list of strings) -- display headings for the images to be displayed
    iou_list (list of floats) -- the IOU values for each class
    dice_score_list (list of floats) -- the Dice Score for each vlass
  '''

  true_img = give_color_to_annotation(labelmaps[1])
  pred_img = give_color_to_annotation(labelmaps[0])

  image = image + 1
  image = image * 127.5
  images = np.uint8([image, pred_img, true_img])

  metrics_by_id = [(idx, iou, dice_score) for idx, (iou, dice_score) in enumerate(zip(iou_list, dice_score_list)) if iou > 0.0]
  metrics_by_id.sort(key=lambda tup: tup[1], reverse=True)  # sorts in place
  
  display_string_list = ["{}: IOU: {} Dice Score: {}".format(class_names[idx], iou, dice_score) for idx, iou, dice_score in metrics_by_id]
  display_string = "\n\n".join(display_string_list) 

  plt.figure(figsize=(15, 4))

  for idx, im in enumerate(images):
    plt.subplot(1, 3, idx+1)
    if idx == 1:
      plt.xlabel(display_string)
    plt.xticks([])
    plt.yticks([])
    plt.title(titles[idx], fontsize=12)
    plt.imshow(im)


def show_annotation_and_image(image, annotation):
  '''
  Displays the image and its annotation side by side

  Args:
    image (numpy array) -- the input image
    annotation (numpy array) -- the label map
  '''
  new_ann = np.argmax(annotation, axis=2)
  seg_img = give_color_to_annotation(new_ann)
  
  image = image + 1
  image = image * 127.5
  image = np.uint8(image)
  images = [image, seg_img]
  
  images = [image, seg_img]
  fused_img = fuse_with_pil(images)
  plt.imshow(fused_img)


def list_show_annotation(dataset):
  '''
  Displays images and its annotations side by side

  Args:
    dataset (tf Dataset) - batch of images and annotations
  '''

  ds = dataset.unbatch()
  ds = ds.shuffle(buffer_size=100)

  plt.figure(figsize=(25, 15))
  plt.title("Images And Annotations")
  plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.05)

  # we set the number of image-annotation pairs to 9
  # feel free to make this a function parameter if you want
  for idx, (image, annotation) in enumerate(ds.take(9)):
    plt.subplot(3, 3, idx + 1)
    plt.yticks([])
    plt.xticks([])
    show_annotation_and_image(image.numpy(), annotation.numpy())

 

Please run the cells below to see sample images from the train and validation sets. You will see the image and the label maps side side by side.

In [13]:
list_show_annotation(training_dataset)
<ipython-input-12-b1e70315b99d>:129: MatplotlibDeprecationWarning: Auto-removal of overlapping axes is deprecated since 3.6 and will be removed two minor releases later; explicitly call ax.remove() as needed.
  plt.subplot(3, 3, idx + 1)
In [14]:
list_show_annotation(validation_dataset)
<ipython-input-12-b1e70315b99d>:129: MatplotlibDeprecationWarning: Auto-removal of overlapping axes is deprecated since 3.6 and will be removed two minor releases later; explicitly call ax.remove() as needed.
  plt.subplot(3, 3, idx + 1)

Define the Model

You will now build the model and prepare it for training. AS mentioned earlier, this will use a VGG-16 network for the encoder and FCN-8 for the decoder. This is the diagram as shown in class:

fcn-8

For this exercise, you will notice a slight difference from the lecture because the dataset images are 224x224 instead of 32x32. You'll see how this is handled in the next cells as you build the encoder.

Define Pooling Block of VGG

As you saw in Course 1 of this specialization, VGG networks have repeating blocks so to make the code neat, it's best to create a function to encapsulate this process. Each block has convolutional layers followed by a max pooling layer which downsamples the image.

In [15]:
def block(x, n_convs, filters, kernel_size, activation, pool_size, pool_stride, block_name):
  '''
  Defines a block in the VGG network.

  Args:
    x (tensor) -- input image
    n_convs (int) -- number of convolution layers to append
    filters (int) -- number of filters for the convolution layers
    activation (string or object) -- activation to use in the convolution
    pool_size (int) -- size of the pooling layer
    pool_stride (int) -- stride of the pooling layer
    block_name (string) -- name of the block

  Returns:
    tensor containing the max-pooled output of the convolutions
  '''

  for i in range(n_convs):
      x = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding='same', name="{}_conv{}".format(block_name, i + 1))(x)
    
  x = tf.keras.layers.MaxPooling2D(pool_size=pool_size, strides=pool_stride, name="{}_pool{}".format(block_name, i+1 ))(x)

  return x

Download VGG weights

First, please run the cell below to get pre-trained weights for VGG-16. You will load this in the next section when you build the encoder network.

In [16]:
# download the weights
!wget https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

# assign to a variable
vgg_weights_path = "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"
--2023-04-28 09:50:22--  https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/64878964/b09fedd4-5983-11e6-8f9f-904ea400969a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230428%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230428T095022Z&X-Amz-Expires=300&X-Amz-Signature=0710c6e1d267a4630bcc54be20741846582fea4d6e2698ab26bde8c0697b53c4&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=64878964&response-content-disposition=attachment%3B%20filename%3Dvgg16_weights_tf_dim_ordering_tf_kernels_notop.h5&response-content-type=application%2Foctet-stream [following]
--2023-04-28 09:50:22--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/64878964/b09fedd4-5983-11e6-8f9f-904ea400969a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230428%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230428T095022Z&X-Amz-Expires=300&X-Amz-Signature=0710c6e1d267a4630bcc54be20741846582fea4d6e2698ab26bde8c0697b53c4&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=64878964&response-content-disposition=attachment%3B%20filename%3Dvgg16_weights_tf_dim_ordering_tf_kernels_notop.h5&response-content-type=application%2Foctet-stream
Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58889256 (56M) [application/octet-stream]
Saving to: ‘vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5’

vgg16_weights_tf_di 100%[===================>]  56.16M  98.7MB/s    in 0.6s    

2023-04-28 09:50:23 (98.7 MB/s) - ‘vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5’ saved [58889256/58889256]

Define VGG-16

You can build the encoder as shown below.

  • You will create 5 blocks with increasing number of filters at each stage.
  • The number of convolutions, filters, kernel size, activation, pool size and pool stride will remain constant.
  • You will load the pretrained weights after creating the VGG 16 network.
  • Additional convolution layers will be appended to extract more features.
  • The output will contain the output of the last layer and the previous four convolution blocks.
In [17]:
def VGG_16(image_input):
  '''
  This function defines the VGG encoder.

  Args:
    image_input (tensor) - batch of images

  Returns:
    tuple of tensors - output of all encoder blocks plus the final convolution layer
  '''

  # create 5 blocks with increasing filters at each stage. 
  # you will save the output of each block (i.e. p1, p2, p3, p4, p5). "p" stands for the pooling layer.
  x = block(image_input,n_convs=2, filters=64, kernel_size=(3,3), activation='relu',pool_size=(2,2), pool_stride=(2,2), block_name='block1')
  p1= x

  x = block(x,n_convs=2, filters=128, kernel_size=(3,3), activation='relu',pool_size=(2,2), pool_stride=(2,2), block_name='block2')
  p2 = x

  x = block(x,n_convs=3, filters=256, kernel_size=(3,3), activation='relu',pool_size=(2,2), pool_stride=(2,2), block_name='block3')
  p3 = x

  x = block(x,n_convs=3, filters=512, kernel_size=(3,3), activation='relu',pool_size=(2,2), pool_stride=(2,2), block_name='block4')
  p4 = x

  x = block(x,n_convs=3, filters=512, kernel_size=(3,3), activation='relu',pool_size=(2,2), pool_stride=(2,2), block_name='block5')
  p5 = x

  # create the vgg model
  vgg  = tf.keras.Model(image_input , p5)

  # load the pretrained weights you downloaded earlier
  vgg.load_weights(vgg_weights_path) 

  # number of filters for the output convolutional layers
  n = 4096

  # our input images are 224x224 pixels so they will be downsampled to 7x7 after the pooling layers above.
  # we can extract more features by chaining two more convolution layers.
  c6 = tf.keras.layers.Conv2D( n , ( 7 , 7 ) , activation='relu' , padding='same', name="conv6")(p5)
  c7 = tf.keras.layers.Conv2D( n , ( 1 , 1 ) , activation='relu' , padding='same', name="conv7")(c6)

  # return the outputs at each stage. you will only need two of these in this particular exercise 
  # but we included it all in case you want to experiment with other types of decoders.
  return (p1, p2, p3, p4, c7)

Define FCN 8 Decoder

Next, you will build the decoder using deconvolution layers. Please refer to the diagram for FCN-8 at the start of this section to visualize what the code below is doing. It will involve two summations before upsampling to the original image size and generating the predicted mask.

In [18]:
def fcn8_decoder(convs, n_classes):
  '''
  Defines the FCN 8 decoder.

  Args:
    convs (tuple of tensors) - output of the encoder network
    n_classes (int) - number of classes

  Returns:
    tensor with shape (height, width, n_classes) containing class probabilities
  '''

  # unpack the output of the encoder
  f1, f2, f3, f4, f5 = convs
  
  # upsample the output of the encoder then crop extra pixels that were introduced
  o = tf.keras.layers.Conv2DTranspose(n_classes , kernel_size=(4,4) ,  strides=(2,2) , use_bias=False )(f5)
  o = tf.keras.layers.Cropping2D(cropping=(1,1))(o)

  # load the pool 4 prediction and do a 1x1 convolution to reshape it to the same shape of `o` above
  o2 = f4
  o2 = ( tf.keras.layers.Conv2D(n_classes , ( 1 , 1 ) , activation='relu' , padding='same'))(o2)

  # add the results of the upsampling and pool 4 prediction
  o = tf.keras.layers.Add()([o, o2])

  # upsample the resulting tensor of the operation you just did
  o = (tf.keras.layers.Conv2DTranspose( n_classes , kernel_size=(4,4) ,  strides=(2,2) , use_bias=False ))(o)
  o = tf.keras.layers.Cropping2D(cropping=(1, 1))(o)

  # load the pool 3 prediction and do a 1x1 convolution to reshape it to the same shape of `o` above
  o2 = f3
  o2 = ( tf.keras.layers.Conv2D(n_classes , ( 1 , 1 ) , activation='relu' , padding='same'))(o2)

  # add the results of the upsampling and pool 3 prediction
  o = tf.keras.layers.Add()([o, o2])
  
  # upsample up to the size of the original image
  o = tf.keras.layers.Conv2DTranspose(n_classes , kernel_size=(8,8) ,  strides=(8,8) , use_bias=False )(o)

  # append a softmax to get the class probabilities
  o = (tf.keras.layers.Activation('softmax'))(o)

  return o

Define Final Model

You can now build the final model by connecting the encoder and decoder blocks.

In [19]:
def segmentation_model():
  '''
  Defines the final segmentation model by chaining together the encoder and decoder.

  Returns:
    keras Model that connects the encoder and decoder networks of the segmentation model
  '''
  
  inputs = tf.keras.layers.Input(shape=(224,224,3,))
  convs = VGG_16(image_input=inputs)
  outputs = fcn8_decoder(convs, 12)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  
  return model
In [20]:
# instantiate the model and see how it looks
model = segmentation_model()
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 224, 224, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 block1_conv2 (Conv2D)          (None, 224, 224, 64  36928       ['block1_conv1[0][0]']           
                                )                                                                 
                                                                                                  
 block1_pool2 (MaxPooling2D)    (None, 112, 112, 64  0           ['block1_conv2[0][0]']           
                                )                                                                 
                                                                                                  
 block2_conv1 (Conv2D)          (None, 112, 112, 12  73856       ['block1_pool2[0][0]']           
                                8)                                                                
                                                                                                  
 block2_conv2 (Conv2D)          (None, 112, 112, 12  147584      ['block2_conv1[0][0]']           
                                8)                                                                
                                                                                                  
 block2_pool2 (MaxPooling2D)    (None, 56, 56, 128)  0           ['block2_conv2[0][0]']           
                                                                                                  
 block3_conv1 (Conv2D)          (None, 56, 56, 256)  295168      ['block2_pool2[0][0]']           
                                                                                                  
 block3_conv2 (Conv2D)          (None, 56, 56, 256)  590080      ['block3_conv1[0][0]']           
                                                                                                  
 block3_conv3 (Conv2D)          (None, 56, 56, 256)  590080      ['block3_conv2[0][0]']           
                                                                                                  
 block3_pool3 (MaxPooling2D)    (None, 28, 28, 256)  0           ['block3_conv3[0][0]']           
                                                                                                  
 block4_conv1 (Conv2D)          (None, 28, 28, 512)  1180160     ['block3_pool3[0][0]']           
                                                                                                  
 block4_conv2 (Conv2D)          (None, 28, 28, 512)  2359808     ['block4_conv1[0][0]']           
                                                                                                  
 block4_conv3 (Conv2D)          (None, 28, 28, 512)  2359808     ['block4_conv2[0][0]']           
                                                                                                  
 block4_pool3 (MaxPooling2D)    (None, 14, 14, 512)  0           ['block4_conv3[0][0]']           
                                                                                                  
 block5_conv1 (Conv2D)          (None, 14, 14, 512)  2359808     ['block4_pool3[0][0]']           
                                                                                                  
 block5_conv2 (Conv2D)          (None, 14, 14, 512)  2359808     ['block5_conv1[0][0]']           
                                                                                                  
 block5_conv3 (Conv2D)          (None, 14, 14, 512)  2359808     ['block5_conv2[0][0]']           
                                                                                                  
 block5_pool3 (MaxPooling2D)    (None, 7, 7, 512)    0           ['block5_conv3[0][0]']           
                                                                                                  
 conv6 (Conv2D)                 (None, 7, 7, 4096)   102764544   ['block5_pool3[0][0]']           
                                                                                                  
 conv7 (Conv2D)                 (None, 7, 7, 4096)   16781312    ['conv6[0][0]']                  
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 16, 16, 12)  786432      ['conv7[0][0]']                  
 ose)                                                                                             
                                                                                                  
 cropping2d (Cropping2D)        (None, 14, 14, 12)   0           ['conv2d_transpose[0][0]']       
                                                                                                  
 conv2d (Conv2D)                (None, 14, 14, 12)   6156        ['block4_pool3[0][0]']           
                                                                                                  
 add (Add)                      (None, 14, 14, 12)   0           ['cropping2d[0][0]',             
                                                                  'conv2d[0][0]']                 
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 30, 30, 12)  2304        ['add[0][0]']                    
 spose)                                                                                           
                                                                                                  
 cropping2d_1 (Cropping2D)      (None, 28, 28, 12)   0           ['conv2d_transpose_1[0][0]']     
                                                                                                  
 conv2d_1 (Conv2D)              (None, 28, 28, 12)   3084        ['block3_pool3[0][0]']           
                                                                                                  
 add_1 (Add)                    (None, 28, 28, 12)   0           ['cropping2d_1[0][0]',           
                                                                  'conv2d_1[0][0]']               
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 224, 224, 12  9216       ['add_1[0][0]']                  
 spose)                         )                                                                 
                                                                                                  
 activation (Activation)        (None, 224, 224, 12  0           ['conv2d_transpose_2[0][0]']     
                                )                                                                 
                                                                                                  
==================================================================================================
Total params: 135,067,736
Trainable params: 135,067,736
Non-trainable params: 0
__________________________________________________________________________________________________

Compile the Model

Next, the model will be configured for training. You will need to specify the loss, optimizer and metrics. You will use categorical_crossentropy as the loss function since the label map is transformed to one hot encoded vectors for each pixel in the image (i.e. 1 in one slice and 0 for other slices as described earlier).

In [ ]:
sgd = tf.keras.optimizers.SGD(learning_rate=1E-2, momentum=0.9, nesterov=True)

model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])

Train the Model

The model can now be trained. This will take around 30 minutes to run and you will reach around 85% accuracy for both train and val sets.

In [22]:
# number of training images
train_count = 367

# number of validation images
validation_count = 101

EPOCHS = 170

steps_per_epoch = train_count//BATCH_SIZE
validation_steps = validation_count//BATCH_SIZE

history = model.fit(training_dataset,
                    steps_per_epoch=steps_per_epoch, validation_data=validation_dataset, validation_steps=validation_steps, epochs=EPOCHS)
Epoch 1/170
5/5 [==============================] - 57s 2s/step - loss: 2.7915 - accuracy: 0.0872 - val_loss: 2.4856 - val_accuracy: 0.0917
Epoch 2/170
5/5 [==============================] - 35s 2s/step - loss: 2.4843 - accuracy: 0.0895 - val_loss: 2.4827 - val_accuracy: 0.0928
Epoch 3/170
5/5 [==============================] - 9s 2s/step - loss: 2.4825 - accuracy: 0.0909 - val_loss: 2.4816 - val_accuracy: 0.0924
Epoch 4/170
5/5 [==============================] - 11s 2s/step - loss: 2.4804 - accuracy: 0.0933 - val_loss: 2.4773 - val_accuracy: 0.0966
Epoch 5/170
5/5 [==============================] - 9s 2s/step - loss: 2.4752 - accuracy: 0.1010 - val_loss: 2.4722 - val_accuracy: 0.1035
Epoch 6/170
5/5 [==============================] - 9s 2s/step - loss: 2.4684 - accuracy: 0.1088 - val_loss: 2.4639 - val_accuracy: 0.1120
Epoch 7/170
5/5 [==============================] - 11s 2s/step - loss: 2.4566 - accuracy: 0.1230 - val_loss: 2.4480 - val_accuracy: 0.1305
Epoch 8/170
5/5 [==============================] - 9s 2s/step - loss: 2.4346 - accuracy: 0.1436 - val_loss: 2.4166 - val_accuracy: 0.1546
Epoch 9/170
5/5 [==============================] - 9s 2s/step - loss: 2.3886 - accuracy: 0.1764 - val_loss: 2.3530 - val_accuracy: 0.1948
Epoch 10/170
5/5 [==============================] - 9s 2s/step - loss: 2.3008 - accuracy: 0.2255 - val_loss: 2.2755 - val_accuracy: 0.2427
Epoch 11/170
5/5 [==============================] - 9s 2s/step - loss: 2.1896 - accuracy: 0.2780 - val_loss: 2.1333 - val_accuracy: 0.2927
Epoch 12/170
5/5 [==============================] - 9s 2s/step - loss: 2.0880 - accuracy: 0.3212 - val_loss: 1.9942 - val_accuracy: 0.3070
Epoch 13/170
5/5 [==============================] - 9s 2s/step - loss: 2.0033 - accuracy: 0.3424 - val_loss: 1.9628 - val_accuracy: 0.3110
Epoch 14/170
5/5 [==============================] - 9s 2s/step - loss: 1.9594 - accuracy: 0.3445 - val_loss: 1.8912 - val_accuracy: 0.3113
Epoch 15/170
5/5 [==============================] - 10s 2s/step - loss: 1.8513 - accuracy: 0.3503 - val_loss: 1.8639 - val_accuracy: 0.3424
Epoch 16/170
5/5 [==============================] - 9s 2s/step - loss: 1.8735 - accuracy: 0.3805 - val_loss: 1.8794 - val_accuracy: 0.3705
Epoch 17/170
5/5 [==============================] - 9s 2s/step - loss: 1.7481 - accuracy: 0.4108 - val_loss: 1.7947 - val_accuracy: 0.4034
Epoch 18/170
5/5 [==============================] - 9s 2s/step - loss: 1.6509 - accuracy: 0.4694 - val_loss: 1.7930 - val_accuracy: 0.4526
Epoch 19/170
5/5 [==============================] - 9s 2s/step - loss: 1.7707 - accuracy: 0.4566 - val_loss: 1.9137 - val_accuracy: 0.4175
Epoch 20/170
5/5 [==============================] - 9s 2s/step - loss: 1.7560 - accuracy: 0.4674 - val_loss: 1.8042 - val_accuracy: 0.3894
Epoch 21/170
5/5 [==============================] - 9s 2s/step - loss: 1.5428 - accuracy: 0.5309 - val_loss: 1.7298 - val_accuracy: 0.5143
Epoch 22/170
5/5 [==============================] - 9s 2s/step - loss: 1.4324 - accuracy: 0.5911 - val_loss: 1.5431 - val_accuracy: 0.5290
Epoch 23/170
5/5 [==============================] - 9s 2s/step - loss: 1.2976 - accuracy: 0.6356 - val_loss: 1.4631 - val_accuracy: 0.5690
Epoch 24/170
5/5 [==============================] - 9s 2s/step - loss: 1.2280 - accuracy: 0.6570 - val_loss: 1.3941 - val_accuracy: 0.5921
Epoch 25/170
5/5 [==============================] - 9s 2s/step - loss: 1.1797 - accuracy: 0.6608 - val_loss: 1.3477 - val_accuracy: 0.5992
Epoch 26/170
5/5 [==============================] - 10s 2s/step - loss: 1.1221 - accuracy: 0.6695 - val_loss: 1.2823 - val_accuracy: 0.6025
Epoch 27/170
5/5 [==============================] - 9s 2s/step - loss: 1.0920 - accuracy: 0.6720 - val_loss: 1.3010 - val_accuracy: 0.6024
Epoch 28/170
5/5 [==============================] - 9s 2s/step - loss: 1.1230 - accuracy: 0.6656 - val_loss: 1.2672 - val_accuracy: 0.5875
Epoch 29/170
5/5 [==============================] - 9s 2s/step - loss: 1.0855 - accuracy: 0.6721 - val_loss: 1.2027 - val_accuracy: 0.6074
Epoch 30/170
5/5 [==============================] - 9s 2s/step - loss: 1.0158 - accuracy: 0.6850 - val_loss: 1.1772 - val_accuracy: 0.6084
Epoch 31/170
5/5 [==============================] - 9s 2s/step - loss: 1.0152 - accuracy: 0.6806 - val_loss: 1.1653 - val_accuracy: 0.6106
Epoch 32/170
5/5 [==============================] - 9s 2s/step - loss: 1.0052 - accuracy: 0.6817 - val_loss: 1.1724 - val_accuracy: 0.6003
Epoch 33/170
5/5 [==============================] - 9s 2s/step - loss: 0.9937 - accuracy: 0.6843 - val_loss: 1.1262 - val_accuracy: 0.6178
Epoch 34/170
5/5 [==============================] - 9s 2s/step - loss: 0.9854 - accuracy: 0.6823 - val_loss: 1.1831 - val_accuracy: 0.5927
Epoch 35/170
5/5 [==============================] - 9s 2s/step - loss: 0.9873 - accuracy: 0.6848 - val_loss: 1.1084 - val_accuracy: 0.6125
Epoch 36/170
5/5 [==============================] - 9s 2s/step - loss: 0.9368 - accuracy: 0.6949 - val_loss: 1.0966 - val_accuracy: 0.6158
Epoch 37/170
5/5 [==============================] - 9s 2s/step - loss: 0.9563 - accuracy: 0.6901 - val_loss: 1.0763 - val_accuracy: 0.6151
Epoch 38/170
5/5 [==============================] - 9s 2s/step - loss: 0.9319 - accuracy: 0.6924 - val_loss: 1.0650 - val_accuracy: 0.6186
Epoch 39/170
5/5 [==============================] - 10s 2s/step - loss: 1.0443 - accuracy: 0.6718 - val_loss: 1.1862 - val_accuracy: 0.5853
Epoch 40/170
5/5 [==============================] - 9s 2s/step - loss: 0.9680 - accuracy: 0.6868 - val_loss: 1.1274 - val_accuracy: 0.6139
Epoch 41/170
5/5 [==============================] - 9s 2s/step - loss: 0.9435 - accuracy: 0.6898 - val_loss: 1.0548 - val_accuracy: 0.6157
Epoch 42/170
5/5 [==============================] - 9s 2s/step - loss: 0.9218 - accuracy: 0.6955 - val_loss: 1.0467 - val_accuracy: 0.6164
Epoch 43/170
5/5 [==============================] - 9s 2s/step - loss: 0.9056 - accuracy: 0.7022 - val_loss: 1.0289 - val_accuracy: 0.6188
Epoch 44/170
5/5 [==============================] - 9s 2s/step - loss: 0.9115 - accuracy: 0.6999 - val_loss: 1.0418 - val_accuracy: 0.6196
Epoch 45/170
5/5 [==============================] - 10s 2s/step - loss: 0.8891 - accuracy: 0.7040 - val_loss: 1.0096 - val_accuracy: 0.6282
Epoch 46/170
5/5 [==============================] - 9s 2s/step - loss: 0.8704 - accuracy: 0.7161 - val_loss: 1.0147 - val_accuracy: 0.6335
Epoch 47/170
5/5 [==============================] - 9s 2s/step - loss: 0.8744 - accuracy: 0.7157 - val_loss: 1.0596 - val_accuracy: 0.6382
Epoch 48/170
5/5 [==============================] - 9s 2s/step - loss: 0.8662 - accuracy: 0.7182 - val_loss: 0.9737 - val_accuracy: 0.6521
Epoch 49/170
5/5 [==============================] - 9s 2s/step - loss: 0.8471 - accuracy: 0.7235 - val_loss: 0.9540 - val_accuracy: 0.6645
Epoch 50/170
5/5 [==============================] - 9s 2s/step - loss: 0.8572 - accuracy: 0.7240 - val_loss: 0.9368 - val_accuracy: 0.6697
Epoch 51/170
5/5 [==============================] - 9s 2s/step - loss: 0.8590 - accuracy: 0.7292 - val_loss: 0.9974 - val_accuracy: 0.6772
Epoch 52/170
5/5 [==============================] - 9s 2s/step - loss: 0.8341 - accuracy: 0.7346 - val_loss: 0.8996 - val_accuracy: 0.7090
Epoch 53/170
5/5 [==============================] - 9s 2s/step - loss: 0.8129 - accuracy: 0.7459 - val_loss: 0.8874 - val_accuracy: 0.7233
Epoch 54/170
5/5 [==============================] - 9s 2s/step - loss: 0.7936 - accuracy: 0.7534 - val_loss: 0.8669 - val_accuracy: 0.7329
Epoch 55/170
5/5 [==============================] - 9s 2s/step - loss: 0.7959 - accuracy: 0.7542 - val_loss: 0.8569 - val_accuracy: 0.7241
Epoch 56/170
5/5 [==============================] - 9s 2s/step - loss: 0.9321 - accuracy: 0.7097 - val_loss: 0.9847 - val_accuracy: 0.6570
Epoch 57/170
5/5 [==============================] - 9s 2s/step - loss: 0.8085 - accuracy: 0.7445 - val_loss: 0.8604 - val_accuracy: 0.7400
Epoch 58/170
5/5 [==============================] - 9s 2s/step - loss: 0.7855 - accuracy: 0.7565 - val_loss: 0.8301 - val_accuracy: 0.7450
Epoch 59/170
5/5 [==============================] - 9s 2s/step - loss: 0.7655 - accuracy: 0.7648 - val_loss: 0.8274 - val_accuracy: 0.7459
Epoch 60/170
5/5 [==============================] - 9s 2s/step - loss: 0.7543 - accuracy: 0.7701 - val_loss: 0.8162 - val_accuracy: 0.7495
Epoch 61/170
5/5 [==============================] - 9s 2s/step - loss: 0.7504 - accuracy: 0.7716 - val_loss: 0.8115 - val_accuracy: 0.7505
Epoch 62/170
5/5 [==============================] - 9s 2s/step - loss: 0.7771 - accuracy: 0.7584 - val_loss: 0.8755 - val_accuracy: 0.7279
Epoch 63/170
5/5 [==============================] - 9s 2s/step - loss: 0.7411 - accuracy: 0.7757 - val_loss: 0.8083 - val_accuracy: 0.7527
Epoch 64/170
5/5 [==============================] - 9s 2s/step - loss: 0.7322 - accuracy: 0.7765 - val_loss: 0.8033 - val_accuracy: 0.7536
Epoch 65/170
5/5 [==============================] - 9s 2s/step - loss: 0.7317 - accuracy: 0.7763 - val_loss: 0.7735 - val_accuracy: 0.7611
Epoch 66/170
5/5 [==============================] - 9s 2s/step - loss: 0.7094 - accuracy: 0.7835 - val_loss: 0.7783 - val_accuracy: 0.7619
Epoch 67/170
5/5 [==============================] - 9s 2s/step - loss: 0.7346 - accuracy: 0.7790 - val_loss: 0.8277 - val_accuracy: 0.7459
Epoch 68/170
5/5 [==============================] - 9s 2s/step - loss: 0.7276 - accuracy: 0.7790 - val_loss: 0.8562 - val_accuracy: 0.7390
Epoch 69/170
5/5 [==============================] - 10s 2s/step - loss: 0.7384 - accuracy: 0.7751 - val_loss: 0.7644 - val_accuracy: 0.7642
Epoch 70/170
5/5 [==============================] - 9s 2s/step - loss: 0.7007 - accuracy: 0.7860 - val_loss: 0.7981 - val_accuracy: 0.7471
Epoch 71/170
5/5 [==============================] - 9s 2s/step - loss: 0.7183 - accuracy: 0.7796 - val_loss: 0.7661 - val_accuracy: 0.7655
Epoch 72/170
5/5 [==============================] - 9s 2s/step - loss: 0.6951 - accuracy: 0.7884 - val_loss: 0.7512 - val_accuracy: 0.7664
Epoch 73/170
5/5 [==============================] - 9s 2s/step - loss: 0.6902 - accuracy: 0.7896 - val_loss: 0.7645 - val_accuracy: 0.7661
Epoch 74/170
5/5 [==============================] - 9s 2s/step - loss: 0.6972 - accuracy: 0.7888 - val_loss: 0.7492 - val_accuracy: 0.7675
Epoch 75/170
5/5 [==============================] - 9s 2s/step - loss: 0.6748 - accuracy: 0.7964 - val_loss: 0.7401 - val_accuracy: 0.7727
Epoch 76/170
5/5 [==============================] - 9s 2s/step - loss: 0.6774 - accuracy: 0.7957 - val_loss: 0.7618 - val_accuracy: 0.7612
Epoch 77/170
5/5 [==============================] - 9s 2s/step - loss: 0.7118 - accuracy: 0.7818 - val_loss: 1.1362 - val_accuracy: 0.6465
Epoch 78/170
5/5 [==============================] - 9s 2s/step - loss: 0.8143 - accuracy: 0.7484 - val_loss: 0.7727 - val_accuracy: 0.7635
Epoch 79/170
5/5 [==============================] - 9s 2s/step - loss: 0.6996 - accuracy: 0.7887 - val_loss: 0.7593 - val_accuracy: 0.7619
Epoch 80/170
5/5 [==============================] - 9s 2s/step - loss: 0.6767 - accuracy: 0.7937 - val_loss: 0.7278 - val_accuracy: 0.7711
Epoch 81/170
5/5 [==============================] - 9s 2s/step - loss: 0.6709 - accuracy: 0.7958 - val_loss: 0.7321 - val_accuracy: 0.7706
Epoch 82/170
5/5 [==============================] - 9s 2s/step - loss: 0.6548 - accuracy: 0.8014 - val_loss: 0.7473 - val_accuracy: 0.7677
Epoch 83/170
5/5 [==============================] - 9s 2s/step - loss: 0.6627 - accuracy: 0.7997 - val_loss: 0.7327 - val_accuracy: 0.7720
Epoch 84/170
5/5 [==============================] - 9s 2s/step - loss: 0.6515 - accuracy: 0.8028 - val_loss: 0.7417 - val_accuracy: 0.7702
Epoch 85/170
5/5 [==============================] - 9s 2s/step - loss: 0.6496 - accuracy: 0.8046 - val_loss: 0.7221 - val_accuracy: 0.7761
Epoch 86/170
5/5 [==============================] - 9s 2s/step - loss: 0.6390 - accuracy: 0.8085 - val_loss: 0.7149 - val_accuracy: 0.7798
Epoch 87/170
5/5 [==============================] - 9s 2s/step - loss: 0.6326 - accuracy: 0.8104 - val_loss: 0.7140 - val_accuracy: 0.7788
Epoch 88/170
5/5 [==============================] - 9s 2s/step - loss: 0.6734 - accuracy: 0.7969 - val_loss: 0.8228 - val_accuracy: 0.7443
Epoch 89/170
5/5 [==============================] - 9s 2s/step - loss: 0.6545 - accuracy: 0.8023 - val_loss: 0.7175 - val_accuracy: 0.7784
Epoch 90/170
5/5 [==============================] - 9s 2s/step - loss: 0.6271 - accuracy: 0.8115 - val_loss: 0.7003 - val_accuracy: 0.7822
Epoch 91/170
5/5 [==============================] - 9s 2s/step - loss: 0.6331 - accuracy: 0.8093 - val_loss: 0.6922 - val_accuracy: 0.7841
Epoch 92/170
5/5 [==============================] - 9s 2s/step - loss: 0.6234 - accuracy: 0.8134 - val_loss: 0.6861 - val_accuracy: 0.7880
Epoch 93/170
5/5 [==============================] - 9s 2s/step - loss: 0.6426 - accuracy: 0.8079 - val_loss: 0.7424 - val_accuracy: 0.7700
Epoch 94/170
5/5 [==============================] - 9s 2s/step - loss: 0.6184 - accuracy: 0.8143 - val_loss: 0.6895 - val_accuracy: 0.7874
Epoch 95/170
5/5 [==============================] - 9s 2s/step - loss: 0.6207 - accuracy: 0.8141 - val_loss: 0.6969 - val_accuracy: 0.7866
Epoch 96/170
5/5 [==============================] - 9s 2s/step - loss: 0.6124 - accuracy: 0.8174 - val_loss: 0.7178 - val_accuracy: 0.7788
Epoch 97/170
5/5 [==============================] - 9s 2s/step - loss: 0.6173 - accuracy: 0.8156 - val_loss: 0.6766 - val_accuracy: 0.7911
Epoch 98/170
5/5 [==============================] - 9s 2s/step - loss: 0.6193 - accuracy: 0.8153 - val_loss: 0.7049 - val_accuracy: 0.7835
Epoch 99/170
5/5 [==============================] - 9s 2s/step - loss: 0.6141 - accuracy: 0.8180 - val_loss: 0.6639 - val_accuracy: 0.7986
Epoch 100/170
5/5 [==============================] - 9s 2s/step - loss: 0.6151 - accuracy: 0.8171 - val_loss: 0.6822 - val_accuracy: 0.7898
Epoch 101/170
5/5 [==============================] - 9s 2s/step - loss: 0.6045 - accuracy: 0.8200 - val_loss: 0.6837 - val_accuracy: 0.7922
Epoch 102/170
5/5 [==============================] - 10s 2s/step - loss: 0.5924 - accuracy: 0.8245 - val_loss: 0.6540 - val_accuracy: 0.8023
Epoch 103/170
5/5 [==============================] - 9s 2s/step - loss: 0.6043 - accuracy: 0.8215 - val_loss: 0.7233 - val_accuracy: 0.7833
Epoch 104/170
5/5 [==============================] - 9s 2s/step - loss: 0.6078 - accuracy: 0.8221 - val_loss: 0.6578 - val_accuracy: 0.8031
Epoch 105/170
5/5 [==============================] - 10s 2s/step - loss: 0.5864 - accuracy: 0.8265 - val_loss: 0.6615 - val_accuracy: 0.8012
Epoch 106/170
5/5 [==============================] - 9s 2s/step - loss: 0.5953 - accuracy: 0.8247 - val_loss: 0.6532 - val_accuracy: 0.8015
Epoch 107/170
5/5 [==============================] - 9s 2s/step - loss: 0.5936 - accuracy: 0.8236 - val_loss: 0.7386 - val_accuracy: 0.7784
Epoch 108/170
5/5 [==============================] - 9s 2s/step - loss: 0.6004 - accuracy: 0.8229 - val_loss: 0.6706 - val_accuracy: 0.8020
Epoch 109/170
5/5 [==============================] - 9s 2s/step - loss: 0.5850 - accuracy: 0.8282 - val_loss: 0.6619 - val_accuracy: 0.8040
Epoch 110/170
5/5 [==============================] - 9s 2s/step - loss: 0.6092 - accuracy: 0.8211 - val_loss: 0.6724 - val_accuracy: 0.8067
Epoch 111/170
5/5 [==============================] - 9s 2s/step - loss: 0.5919 - accuracy: 0.8280 - val_loss: 0.6645 - val_accuracy: 0.8018
Epoch 112/170
5/5 [==============================] - 9s 2s/step - loss: 0.5819 - accuracy: 0.8296 - val_loss: 0.6249 - val_accuracy: 0.8193
Epoch 113/170
5/5 [==============================] - 9s 2s/step - loss: 0.5657 - accuracy: 0.8342 - val_loss: 0.6357 - val_accuracy: 0.8135
Epoch 114/170
5/5 [==============================] - 9s 2s/step - loss: 0.5609 - accuracy: 0.8359 - val_loss: 0.6405 - val_accuracy: 0.8116
Epoch 115/170
5/5 [==============================] - 9s 2s/step - loss: 0.5730 - accuracy: 0.8322 - val_loss: 0.6676 - val_accuracy: 0.8073
Epoch 116/170
5/5 [==============================] - 9s 2s/step - loss: 0.5816 - accuracy: 0.8300 - val_loss: 0.6299 - val_accuracy: 0.8154
Epoch 117/170
5/5 [==============================] - 10s 2s/step - loss: 0.5676 - accuracy: 0.8352 - val_loss: 0.6311 - val_accuracy: 0.8165
Epoch 118/170
5/5 [==============================] - 9s 2s/step - loss: 0.5612 - accuracy: 0.8355 - val_loss: 0.6293 - val_accuracy: 0.8197
Epoch 119/170
5/5 [==============================] - 9s 2s/step - loss: 0.5825 - accuracy: 0.8318 - val_loss: 0.6236 - val_accuracy: 0.8255
Epoch 120/170
5/5 [==============================] - 9s 2s/step - loss: 0.5901 - accuracy: 0.8279 - val_loss: 0.6478 - val_accuracy: 0.8063
Epoch 121/170
5/5 [==============================] - 11s 2s/step - loss: 0.5619 - accuracy: 0.8355 - val_loss: 0.6154 - val_accuracy: 0.8239
Epoch 122/170
5/5 [==============================] - 9s 2s/step - loss: 0.5457 - accuracy: 0.8428 - val_loss: 0.6112 - val_accuracy: 0.8248
Epoch 123/170
5/5 [==============================] - 9s 2s/step - loss: 0.5615 - accuracy: 0.8361 - val_loss: 0.6200 - val_accuracy: 0.8219
Epoch 124/170
5/5 [==============================] - 9s 2s/step - loss: 0.5399 - accuracy: 0.8435 - val_loss: 0.6044 - val_accuracy: 0.8246
Epoch 125/170
5/5 [==============================] - 9s 2s/step - loss: 0.5478 - accuracy: 0.8413 - val_loss: 0.6201 - val_accuracy: 0.8228
Epoch 126/170
5/5 [==============================] - 9s 2s/step - loss: 0.5403 - accuracy: 0.8431 - val_loss: 0.5931 - val_accuracy: 0.8291
Epoch 127/170
5/5 [==============================] - 9s 2s/step - loss: 0.5615 - accuracy: 0.8385 - val_loss: 0.6153 - val_accuracy: 0.8253
Epoch 128/170
5/5 [==============================] - 9s 2s/step - loss: 0.5613 - accuracy: 0.8388 - val_loss: 0.6518 - val_accuracy: 0.8109
Epoch 129/170
5/5 [==============================] - 10s 2s/step - loss: 0.5501 - accuracy: 0.8404 - val_loss: 0.5955 - val_accuracy: 0.8300
Epoch 130/170
5/5 [==============================] - 9s 2s/step - loss: 0.5287 - accuracy: 0.8478 - val_loss: 0.6034 - val_accuracy: 0.8284
Epoch 131/170
5/5 [==============================] - 9s 2s/step - loss: 0.5446 - accuracy: 0.8415 - val_loss: 0.6497 - val_accuracy: 0.8151
Epoch 132/170
5/5 [==============================] - 9s 2s/step - loss: 0.5567 - accuracy: 0.8400 - val_loss: 0.5922 - val_accuracy: 0.8305
Epoch 133/170
5/5 [==============================] - 9s 2s/step - loss: 0.5433 - accuracy: 0.8426 - val_loss: 0.6224 - val_accuracy: 0.8170
Epoch 134/170
5/5 [==============================] - 9s 2s/step - loss: 0.5250 - accuracy: 0.8488 - val_loss: 0.5738 - val_accuracy: 0.8354
Epoch 135/170
5/5 [==============================] - 9s 2s/step - loss: 0.5322 - accuracy: 0.8483 - val_loss: 0.6809 - val_accuracy: 0.8047
Epoch 136/170
5/5 [==============================] - 9s 2s/step - loss: 0.5709 - accuracy: 0.8342 - val_loss: 0.5995 - val_accuracy: 0.8260
Epoch 137/170
5/5 [==============================] - 9s 2s/step - loss: 0.5189 - accuracy: 0.8505 - val_loss: 0.5883 - val_accuracy: 0.8315
Epoch 138/170
5/5 [==============================] - 9s 2s/step - loss: 0.5218 - accuracy: 0.8493 - val_loss: 0.5813 - val_accuracy: 0.8317
Epoch 139/170
5/5 [==============================] - 10s 2s/step - loss: 0.5202 - accuracy: 0.8506 - val_loss: 0.5746 - val_accuracy: 0.8334
Epoch 140/170
5/5 [==============================] - 9s 2s/step - loss: 0.5169 - accuracy: 0.8518 - val_loss: 0.6283 - val_accuracy: 0.8209
Epoch 141/170
5/5 [==============================] - 9s 2s/step - loss: 0.5340 - accuracy: 0.8456 - val_loss: 0.5945 - val_accuracy: 0.8282
Epoch 142/170
5/5 [==============================] - 9s 2s/step - loss: 0.5269 - accuracy: 0.8481 - val_loss: 0.5641 - val_accuracy: 0.8369
Epoch 143/170
5/5 [==============================] - 9s 2s/step - loss: 0.5095 - accuracy: 0.8524 - val_loss: 0.5770 - val_accuracy: 0.8329
Epoch 144/170
5/5 [==============================] - 9s 2s/step - loss: 0.5081 - accuracy: 0.8549 - val_loss: 0.5628 - val_accuracy: 0.8360
Epoch 145/170
5/5 [==============================] - 9s 2s/step - loss: 0.5609 - accuracy: 0.8381 - val_loss: 0.6175 - val_accuracy: 0.8201
Epoch 146/170
5/5 [==============================] - 9s 2s/step - loss: 0.5167 - accuracy: 0.8517 - val_loss: 0.5757 - val_accuracy: 0.8313
Epoch 147/170
5/5 [==============================] - 10s 2s/step - loss: 0.5087 - accuracy: 0.8544 - val_loss: 0.5579 - val_accuracy: 0.8372
Epoch 148/170
5/5 [==============================] - 9s 2s/step - loss: 0.5063 - accuracy: 0.8548 - val_loss: 0.5601 - val_accuracy: 0.8374
Epoch 149/170
5/5 [==============================] - 9s 2s/step - loss: 0.4958 - accuracy: 0.8573 - val_loss: 0.5644 - val_accuracy: 0.8328
Epoch 150/170
5/5 [==============================] - 9s 2s/step - loss: 0.5111 - accuracy: 0.8532 - val_loss: 0.5520 - val_accuracy: 0.8380
Epoch 151/170
5/5 [==============================] - 11s 2s/step - loss: 0.5018 - accuracy: 0.8559 - val_loss: 0.5776 - val_accuracy: 0.8290
Epoch 152/170
5/5 [==============================] - 9s 2s/step - loss: 0.4980 - accuracy: 0.8568 - val_loss: 0.5622 - val_accuracy: 0.8356
Epoch 153/170
5/5 [==============================] - 10s 2s/step - loss: 0.4985 - accuracy: 0.8567 - val_loss: 0.5496 - val_accuracy: 0.8391
Epoch 154/170
5/5 [==============================] - 9s 2s/step - loss: 0.5075 - accuracy: 0.8543 - val_loss: 0.5919 - val_accuracy: 0.8320
Epoch 155/170
5/5 [==============================] - 9s 2s/step - loss: 0.5179 - accuracy: 0.8504 - val_loss: 0.5880 - val_accuracy: 0.8290
Epoch 156/170
5/5 [==============================] - 9s 2s/step - loss: 0.4970 - accuracy: 0.8578 - val_loss: 0.5459 - val_accuracy: 0.8403
Epoch 157/170
5/5 [==============================] - 9s 2s/step - loss: 0.4883 - accuracy: 0.8604 - val_loss: 0.5617 - val_accuracy: 0.8348
Epoch 158/170
5/5 [==============================] - 9s 2s/step - loss: 0.4847 - accuracy: 0.8609 - val_loss: 0.5320 - val_accuracy: 0.8433
Epoch 159/170
5/5 [==============================] - 10s 2s/step - loss: 0.5355 - accuracy: 0.8450 - val_loss: 0.6657 - val_accuracy: 0.8127
Epoch 160/170
5/5 [==============================] - 9s 2s/step - loss: 0.5074 - accuracy: 0.8554 - val_loss: 0.5521 - val_accuracy: 0.8410
Epoch 161/170
5/5 [==============================] - 9s 2s/step - loss: 0.4858 - accuracy: 0.8600 - val_loss: 0.5583 - val_accuracy: 0.8372
Epoch 162/170
5/5 [==============================] - 9s 2s/step - loss: 0.4856 - accuracy: 0.8611 - val_loss: 0.5428 - val_accuracy: 0.8413
Epoch 163/170
5/5 [==============================] - 11s 2s/step - loss: 0.4849 - accuracy: 0.8614 - val_loss: 0.5389 - val_accuracy: 0.8403
Epoch 164/170
5/5 [==============================] - 9s 2s/step - loss: 0.4930 - accuracy: 0.8577 - val_loss: 0.5670 - val_accuracy: 0.8327
Epoch 165/170
5/5 [==============================] - 9s 2s/step - loss: 0.4909 - accuracy: 0.8596 - val_loss: 0.5349 - val_accuracy: 0.8416
Epoch 166/170
5/5 [==============================] - 9s 2s/step - loss: 0.4762 - accuracy: 0.8637 - val_loss: 0.5526 - val_accuracy: 0.8376
Epoch 167/170
5/5 [==============================] - 9s 2s/step - loss: 0.4810 - accuracy: 0.8624 - val_loss: 0.5423 - val_accuracy: 0.8419
Epoch 168/170
5/5 [==============================] - 9s 2s/step - loss: 0.4747 - accuracy: 0.8651 - val_loss: 0.5307 - val_accuracy: 0.8439
Epoch 169/170
5/5 [==============================] - 9s 2s/step - loss: 0.4769 - accuracy: 0.8636 - val_loss: 0.5726 - val_accuracy: 0.8330
Epoch 170/170
5/5 [==============================] - 10s 2s/step - loss: 0.5019 - accuracy: 0.8568 - val_loss: 0.5472 - val_accuracy: 0.8379

Evaluate the Model

After training, you will want to see how your model is doing on a test set. For segmentation models, you can use the intersection-over-union and the dice score as metrics to evaluate your model. You'll see how it is implemented in this section.

In [23]:
def get_images_and_segments_test_arrays():
  '''
  Gets a subsample of the val set as your test set

  Returns:
    Test set containing ground truth images and label maps
  '''
  y_true_segments = []
  y_true_images = []
  test_count = 64

  ds = validation_dataset.unbatch()
  ds = ds.batch(101)

  for image, annotation in ds.take(1):
    y_true_images = image
    y_true_segments = annotation


  y_true_segments = y_true_segments[:test_count, : ,: , :]
  y_true_segments = np.argmax(y_true_segments, axis=3)  

  return y_true_images, y_true_segments

# load the ground truth images and segmentation masks
y_true_images, y_true_segments = get_images_and_segments_test_arrays()

Make Predictions

You can get output segmentation masks by using the predict() method. As you may recall, the output of our segmentation model has the shape (height, width, 12) where 12 is the number of classes. Each pixel value in those 12 slices indicates the probability of that pixel belonging to that particular class. If you want to create the predicted label map, then you can get the argmax() of that axis. This is shown in the following cell.

In [24]:
# get the model prediction
results = model.predict(validation_dataset, steps=validation_steps)

# for each pixel, get the slice number which has the highest probability
results = np.argmax(results, axis=3)
1/1 [==============================] - 1s 821ms/step

Compute Metrics

The function below generates the IOU and dice score of the prediction and ground truth masks. From the lectures, it is given that:

$$IOU = \frac{area\_of\_overlap}{area\_of\_union}$$
$$Dice Score = 2 * \frac{area\_of\_overlap}{combined\_area}$$

The code below does that for you. A small smoothening factor is introduced in the denominators to prevent possible division by zero.

In [25]:
def compute_metrics(y_true, y_pred):
  '''
  Computes IOU and Dice Score.

  Args:
    y_true (tensor) - ground truth label map
    y_pred (tensor) - predicted label map
  '''
  
  class_wise_iou = []
  class_wise_dice_score = []

  smoothening_factor = 0.00001

  for i in range(12):
    intersection = np.sum((y_pred == i) * (y_true == i))
    y_true_area = np.sum((y_true == i))
    y_pred_area = np.sum((y_pred == i))
    combined_area = y_true_area + y_pred_area
    
    iou = (intersection + smoothening_factor) / (combined_area - intersection + smoothening_factor)
    class_wise_iou.append(iou)
    
    dice_score =  2 * ((intersection + smoothening_factor) / (combined_area + smoothening_factor))
    class_wise_dice_score.append(dice_score)

  return class_wise_iou, class_wise_dice_score

Show Predictions and Metrics

You can now see the predicted segmentation masks side by side with the ground truth. The metrics are also overlayed so you can evaluate how your model is doing.

In [26]:
# input a number from 0 to 63 to pick an image from the test set
integer_slider = 0

# compute metrics
iou, dice_score = compute_metrics(y_true_segments[integer_slider], results[integer_slider])  

# visualize the output and metrics
show_predictions(y_true_images[integer_slider], [results[integer_slider], y_true_segments[integer_slider]], ["Image", "Predicted Mask", "True Mask"], iou, dice_score)

Display Class Wise Metrics

You can also compute the class-wise metrics so you can see how your model performs across all images in the test set.

In [27]:
# compute class-wise metrics
cls_wise_iou, cls_wise_dice_score = compute_metrics(y_true_segments, results)
In [28]:
# print IOU for each class
for idx, iou in enumerate(cls_wise_iou):
  spaces = ' ' * (13-len(class_names[idx]) + 2)
  print("{}{}{} ".format(class_names[idx], spaces, iou)) 
sky            0.884328597649985 
building       0.7578834162759777 
column/pole    4.825323294332502e-10 
road           0.8928438306067477 
side walk      0.6061514128872904 
vegetation     0.8663567330281644 
traffic light  3.134206731293736e-10 
fence          0.049007125981091956 
vehicle        0.3804444278192016 
pedestrian     0.03281799887991064 
byciclist      0.0031137938103221827 
void           0.0915496684870582 
In [29]:
# print the dice score for each class
for idx, dice_score in enumerate(cls_wise_dice_score):
  spaces = ' ' * (13-len(class_names[idx]) + 2)
  print("{}{}{} ".format(class_names[idx], spaces, dice_score)) 
sky            0.9386139962721164 
building       0.862268122292149 
column/pole    9.650646588665004e-10 
road           0.943388795392842 
side walk      0.7547873855906326 
vegetation     0.9283935034550277 
traffic light  6.268413462587472e-10 
fence          0.09343525848729844 
vehicle        0.5511912253323258 
pedestrian     0.06355040080259163 
byciclist      0.006208256391236611 
void           0.1677425611265993 

That's all for this lab! In the next section, you will work on another architecture for building a segmentation model: the UNET.